#!/bin/bash

# gNMI port used by gnmi_reboot_dpu when CONFIG_DB has no per-DPU "gnmi_port".
declare -r GNMI_PORT=8080 # Default GNMI port
# Reboot-type strings passed through reboot_dpu to ModuleHelper.reboot_module():
# "DPU" when a single DPU is rebooted, "SMARTSWITCH" when every DPU is rebooted
# together (reboot_dpu skips the PCI reattach in that case).
declare -r MODULE_REBOOT_DPU="DPU"
declare -r MODULE_REBOOT_SMARTSWITCH="SMARTSWITCH"

# Status returned by reboot_dpu when the target DPU is not online.
# NOTE(review): EXIT_SUCCESS and EXIT_ERROR are used throughout this file but
# are not defined in it; presumably the script that sources this file
# provides them -- confirm.
declare -r EXIT_DPU_DOWN=2

# Write a timestamped diagnostic line ("YYYY-MM-DD HH:MM:SS - <text>") to stderr.
# Arguments: $1 - message text
function log_message() {
    local text=$1
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    echo "${stamp} - ${text}" >&2
}

# Succeeds when utilities_common.chassis.is_smartswitch() prints "True".
# The matching line is also echoed on stdout by grep.
function is_smartswitch()
{
    local probe="from utilities_common.chassis import is_smartswitch; print(is_smartswitch())"
    python3 -c "${probe}" | grep "True"
}

# Succeeds when utilities_common.chassis.is_dpu() prints "True".
# The matching line is also echoed on stdout by grep.
function is_dpu()
{
    local probe="from utilities_common.chassis import is_dpu; print(is_dpu())"
    python3 -c "${probe}" | grep "True"
}

# Print whatever utilities_common.chassis.get_num_dpus() returns on stdout.
function get_num_dpus()
{
    local query="from utilities_common.chassis import get_num_dpus; print(get_num_dpus())"
    python3 -c "${query}"
}

# Print the "ips@" field of the DPU's bridge-midplane DHCP entry in CONFIG_DB.
# Arguments: $1 - DPU name used in the CONFIG_DB key
function get_dpu_ip()
{
    local module=$1
    local db_key="DHCP_SERVER_IPV4_PORT|bridge-midplane|${module}"
    sonic-db-cli CONFIG_DB HGET "${db_key}" "ips@"
}

# Print the "gnmi_port" field of the DPU's DPU_PORT entry in CONFIG_DB.
# Arguments: $1 - DPU name used in the CONFIG_DB key
function get_gnmi_port()
{
    local module=$1
    local db_key="DPU_PORT|${module}"
    sonic-db-cli CONFIG_DB HGET "${db_key}" "gnmi_port"
}

# Ask a DPU for its gNOI System.RebootStatus and report whether the reboot
# (service halt) is no longer active.
# Globals:   EXIT_SUCCESS, EXIT_ERROR (read; not defined in this file)
# Arguments: $1 - DPU IP address
#            $2 - gNMI port on the DPU
# Returns:   EXIT_SUCCESS when the response's "active" field is "false";
#            EXIT_ERROR when the RPC fails or the reboot is still active.
function get_reboot_status()
{
    local dpu_ip=$1
    local port=$2

    # The response has to be captured: discarding stdout (as was done before)
    # left nothing to parse, so this function could never report success.
    # Only stderr (gnoi_client logging, -logtostderr) is dropped.
    local reboot_status
    if ! reboot_status=$(docker exec gnmi gnoi_client -target "${dpu_ip}:${port}" -logtostderr -notls -module System -rpc RebootStatus 2>/dev/null); then
        return ${EXIT_ERROR}
    fi

    # NOTE(review): the parsing assumes a line whose second whitespace-separated
    # field is the value of "active" -- confirm against real gnoi_client output.
    local is_reboot_active
    is_reboot_active=$(awk '/active/ {print $2}' <<< "$reboot_status")
    if [ "$is_reboot_active" == "false" ]; then
        return ${EXIT_SUCCESS}
    fi
    return ${EXIT_ERROR}
}

# Detach a DPU's PCI device: try ModuleHelper.pci_detach_module() first and,
# when that call fails, fall back to the device's sysfs "remove" attribute.
# Arguments: $1 - DPU name handed to ModuleHelper
#            $2 - PCI bus info used in the sysfs fallback path
function pci_detach_module()
{
    local DPU_NAME=$1
    local DPU_BUS_INFO=$2
    local vendor_call="from utilities_common.module import ModuleHelper; helper = ModuleHelper(); helper.pci_detach_module('${DPU_NAME}')"

    if ! python3 -c "${vendor_call}"; then
        log_message "ERROR: PCI detach vendor API is not available"
        echo 1 > "/sys/bus/pci/devices/${DPU_BUS_INFO}/remove"
    fi
}

# Re-attach a DPU's PCI device: try ModuleHelper.pci_reattach_module() first
# and, when that call fails, fall back to a PCI bus rescan through sysfs.
# Arguments: $1 - DPU name handed to ModuleHelper
#            $2 - PCI bus info (accepted for interface compatibility; the
#                 fallback rescan does not need it)
function pci_reattach_module()
{
    local DPU_NAME=$1
    python3 -c "from utilities_common.module import ModuleHelper; helper = ModuleHelper(); helper.pci_reattach_module('${DPU_NAME}')"
    if [ $? -ne 0 ]; then
        log_message "ERROR: PCI reattach vendor API is not available"
        # Once the device has been hot-removed, its per-device sysfs directory
        # (and the "rescan" attribute inside it) is gone, so the rescan has to
        # be requested at bus level to rediscover the device.
        echo 1 > /sys/bus/pci/rescan
    fi
}

# Reboot a DPU through the platform API (ModuleHelper.reboot_module).
# Arguments: $1 - DPU name
#            $2 - reboot type string passed to reboot_module
# Returns:   exit status of the python3 invocation, so callers can detect a
#            failed platform call.
function reboot_dpu_platform()
{
    local DPU_NAME=$1
    local REBOOT_TYPE=$2
    # Log first: logging after the call would both misorder the message and
    # replace the function's return status with that of log_message.
    log_message "INFO: Rebooting ${DPU_NAME} with reboot_type:${REBOOT_TYPE}..."
    python3 -c "from utilities_common.module import ModuleHelper; helper = ModuleHelper(); helper.reboot_module('${DPU_NAME}', '${REBOOT_TYPE}')"
}

# Poll get_reboot_status until the DPU reports that it has halted its services
# or the configured timeout elapses.
# Globals:   PLATFORM_JSON_PATH (read), EXIT_SUCCESS, EXIT_ERROR (read)
# Arguments: $1 - DPU IP address
#            $2 - gNMI port on the DPU
#            $3 - DPU name (used in log messages only)
# Returns:   0 both on success and on timeout (the timeout is logged; waiting
#            is best-effort). Exits the current shell with EXIT_ERROR when
#            PLATFORM_JSON_PATH is empty.
function wait_for_dpu_reboot_status()
{
    local dpu_ip=$1
    local port=$2
    local DPU_NAME=$3

    if [[ -z "$PLATFORM_JSON_PATH" ]]; then
        log_message "ERROR: PLATFORM_JSON_PATH is not defined"
        exit $EXIT_ERROR
    fi

    # Timeout in seconds from platform.json. Anything that is not a plain
    # non-negative integer (missing key -> "null", jq failure -> empty, or a
    # malformed value) falls back to the default; a non-numeric value would
    # otherwise break the comparison below and the loop would never time out.
    local dpu_halt_services_timeout
    dpu_halt_services_timeout=$(jq -r '.dpu_halt_services_timeout' "$PLATFORM_JSON_PATH" 2>/dev/null)
    if [[ ! "$dpu_halt_services_timeout" =~ ^[0-9]+$ ]]; then
        # Default timeout
        dpu_halt_services_timeout=60
    fi

    local poll_interval=5
    local waited_time=0
    local status
    while true; do
        get_reboot_status "${dpu_ip}" "${port}"
        status=$?
        if [ "$status" -eq ${EXIT_SUCCESS} ]; then
            log_message "INFO: ${DPU_NAME} halted the services successfully"
            break
        fi

        sleep "$poll_interval"
        waited_time=$((waited_time + poll_interval))
        if [ "$waited_time" -ge "$dpu_halt_services_timeout" ]; then
            log_message "ERROR: Timeout waiting for ${DPU_NAME} to finish halting the services"
            return 0
        fi
    done
    return 0
}

# Ask a DPU, over gNOI, to halt its services ahead of a platform reboot, then
# wait for the halt to complete.
# Globals:   GNMI_PORT (read), EXIT_ERROR (read)
# Arguments: $1 - DPU name. When omitted, the caller's DPU_NAME variable is
#                 used, which is what this function relied on previously.
# Returns:   EXIT_ERROR when no IP can be found for the DPU. A failed gNOI
#            request is logged but not treated as fatal (status of the last
#            log_message); otherwise the status of wait_for_dpu_reboot_status.
function gnmi_reboot_dpu()
{
    local DPU_NAME=${1:-${DPU_NAME}}
    # Kept local so repeated calls cannot see each other's values.
    local dpu_ip port

    # Retrieve DPU IP and GNMI port
    dpu_ip=$(get_dpu_ip "${DPU_NAME}")
    port=$(get_gnmi_port "${DPU_NAME}")
    if [ -z "$port" ]; then
        port=$GNMI_PORT # Default GNMI port
    fi
    log_message "INFO: Rebooting ${DPU_NAME}, ip:$dpu_ip gnmi_port:$port"

    if [ -z "$dpu_ip" ]; then
        log_message "ERROR: Failed to retrieve DPU IP for ${DPU_NAME}"
        return ${EXIT_ERROR}
    fi

    if ! docker exec gnmi gnoi_client -target "${dpu_ip}:${port}" -logtostderr -notls -module System -rpc Reboot -jsonin '{"method":3, "message":"User initiated reboot"}' &>/dev/null; then
        log_message "ERROR: Failed to send gnoi command to halt services on ${DPU_NAME}"
        log_message "ERROR: proceeding without halting the services"
    else
        # Wait for DPU to halt services, if reboot command is successful
        wait_for_dpu_reboot_status "${dpu_ip}" "${port}" "${DPU_NAME}"
    fi
}

# Reboot a single DPU: check that it is online, ask it over gNOI to halt its
# services, detach its PCI device, reboot it through the platform API and,
# unless the whole smart switch is rebooting, re-attach the PCI device.
# Globals:   PLATFORM_JSON_PATH, MODULE_REBOOT_SMARTSWITCH, EXIT_ERROR,
#            EXIT_DPU_DOWN (all read)
# Arguments: $1 - DPU name (e.g. "dpu0")
#            $2 - reboot type string
# Returns:   EXIT_DPU_DOWN when the DPU is not online; EXIT_ERROR when the bus
#            info lookup, PCI detach or platform reboot fails; otherwise the
#            status of the final step.
function reboot_dpu()
{
    local DPU_NAME=$1
    local REBOOT_TYPE=$2

    debug "User requested rebooting device ${DPU_NAME} ..."

    # Check if the DPU operation status is online before rebooting
    # NOTE(review): "$?" below is the status of sed, the last pipeline stage,
    # so unless the caller enabled pipefail a failing `show` does not reach
    # the error branch; it yields an empty status and EXIT_DPU_DOWN instead.
    # Left unchanged because fixing it would alter which DPUs get rebooted.
    local oper_status
    oper_status=$(show chassis modules status "${DPU_NAME^^}" | sed -n '/^ *DPU/ s/.*\s\+\(Online\|Offline\)\s\+.*/\1/p')
    if [ $? -ne 0 ]; then
        log_message "ERROR: Failed to retrieve DPU status."
    else
        log_message "INFO: DPU ${DPU_NAME} is in '$oper_status' state before reboot."
        oper_status=$(echo "$oper_status" | tr '[:upper:]' '[:lower:]')
        if [ "$oper_status" != "online" ]; then
            log_message "INFO: ${DPU_NAME} is not online. Current status: $oper_status"
            return ${EXIT_DPU_DOWN}
        fi
    fi

    # Send reboot command to DPU; a failure here is logged but not fatal.
    gnmi_reboot_dpu "${DPU_NAME}"
    if [ $? -ne 0 ]; then
        log_message "ERROR: Failed to send gnoi command to reboot ${DPU_NAME}"
    fi

    # Declared separately so `local` does not mask the status of jq.
    local DPU_BUS_INFO
    DPU_BUS_INFO=$(jq -r --arg DPU_NAME "$DPU_NAME" '.DPUS[$DPU_NAME].bus_info' "$PLATFORM_JSON_PATH")
    if [ -z "$DPU_BUS_INFO" ] || [ "$DPU_BUS_INFO" = "null" ]; then
        log_message "ERROR: Failed to retrieve bus info for ${DPU_NAME}"
        return ${EXIT_ERROR}
    fi

    pci_detach_module "${DPU_NAME}" "${DPU_BUS_INFO}"
    if [ $? -ne 0 ]; then
        log_message "ERROR: Failed to detach PCI module for ${DPU_NAME}"
        return ${EXIT_ERROR}
    fi

    reboot_dpu_platform "${DPU_NAME}" "${REBOOT_TYPE}"
    if [ $? -ne 0 ]; then
        log_message "ERROR: Failed to send platform command to reboot ${DPU_NAME}"
        return ${EXIT_ERROR}
    fi

    # The PCI device is re-attached only for a DPU-only reboot.
    if [[ "$REBOOT_TYPE" != "$MODULE_REBOOT_SMARTSWITCH" ]]; then
        pci_reattach_module "${DPU_NAME}" "${DPU_BUS_INFO}"
    fi
}

# Reboot every DPU (dpu0 .. dpu<N-1>) in parallel and report how many failed.
# Globals:   MODULE_REBOOT_SMARTSWITCH, EXIT_DPU_DOWN (read)
# Arguments: $1 - number of DPUs
# Returns:   number of DPUs whose reboot_dpu job failed; 0 when all succeeded
#            (or when $1 is empty, which is only logged).
function reboot_all_dpus() {
    local NUM_DPU=$1

    if [[ -z $NUM_DPU ]]; then
        log_message "ERROR: Failed to retrieve number of DPUs or no DPUs found"
        return
    fi

    # Launch one background job per DPU and remember its PID. "$?" right after
    # "cmd &" only reports that the job was started, so the real outcome has to
    # be collected with "wait <pid>".
    local -a pids=()
    local i
    for (( i = 0; i < NUM_DPU; i++ )); do
        reboot_dpu "dpu$i" "$MODULE_REBOOT_SMARTSWITCH" &
        pids+=("$!")
    done

    local failures=0
    local pid rc
    for pid in "${pids[@]}"; do
        wait "$pid"
        rc=$?
        # NOTE(review): a DPU that is already down (EXIT_DPU_DOWN) is not
        # counted, on the assumption that it must not block a whole-switch
        # reboot -- confirm this policy with the caller.
        if [[ $rc -ne 0 && $rc -ne ${EXIT_DPU_DOWN} ]]; then
            failures=$((failures + 1))
        fi
    done
    return $failures
}

# Check that a module name has the form "dpu<index>" with 0 <= index < NUM_DPU.
# Globals:   EXIT_ERROR (read)
# Arguments: $1 - lower-case DPU module name, e.g. "dpu3"
#            $2 - number of DPUs
# Returns:   0 when the name is valid, EXIT_ERROR otherwise (reason is logged).
function verify_dpu_module_name() {
    local DPU_MODULE_NAME=$1
    local NUM_DPU=$2

    if [[ -z "$DPU_MODULE_NAME" ]]; then
        log_message "ERROR: DPU module name not provided"
        return $EXIT_ERROR
    fi

    if [[ ! "$NUM_DPU" =~ ^[0-9]+$ ]]; then
        log_message "ERROR: Failed to retrieve number of DPUs or no DPUs found"
        return $EXIT_ERROR
    fi

    # The index is compared numerically. A bracket expression such as
    # [0-$max] matches single characters only, which rejects dpu10 and above
    # and becomes an invalid range when there are no DPUs.
    local name_re='^dpu(0|[1-9][0-9]*)$'
    if [[ ! "$DPU_MODULE_NAME" =~ $name_re ]]; then
        log_message "ERROR: Invalid DPU module name provided"
        return $EXIT_ERROR
    fi

    local index=${BASH_REMATCH[1]}
    # The length check rejects absurdly long indexes before any arithmetic
    # (no leading zeros are allowed, so more digits always means a larger
    # number); 10# forces base 10.
    if (( ${#index} > ${#NUM_DPU} )) || (( 10#$index >= 10#$NUM_DPU )); then
        log_message "ERROR: Invalid DPU module name provided"
        return $EXIT_ERROR
    fi
    return 0
}

# Validate the reboot options against the kind of device this runs on and
# trigger the matching DPU reboot(s).
# Globals:   NUM_DPU, DPU_MODULE_NAME, result (written);
#            MODULE_REBOOT_DPU, MODULE_REBOOT_SMARTSWITCH, EXIT_SUCCESS,
#            EXIT_ERROR (read)
# Arguments: $1 - "yes" when a single DPU reboot was requested ('-d')
#            $2 - "yes" when pre-shutdown was requested ('-p')
#            $3 - DPU name given with '-d' (any letter case)
# Returns:   EXIT_ERROR on an invalid option combination, otherwise the status
#            of the reboot that was run (0 when there was nothing to do).
function handle_smart_switch() {
    local REBOOT_DPU=$1
    local PRE_SHUTDOWN=$2
    local DPU_NAME=$3

    NUM_DPU=$(get_num_dpus)

    # Running on a DPU: only a pre-shutdown without '-d' is acceptable.
    if is_dpu; then
        if [[ "$PRE_SHUTDOWN" != "yes" ]]; then
            log_message "ERROR: '-p' option not specified for a DPU"
            return $EXIT_ERROR
        fi
        if [[ "$REBOOT_DPU" == "yes" ]]; then
            log_message "ERROR: '-d' option specified for a DPU"
            return $EXIT_ERROR
        fi
        return $EXIT_SUCCESS
    fi

    # From here on we are not on a DPU, where '-p' makes no sense.
    if [[ "$PRE_SHUTDOWN" == "yes" ]]; then
        log_message "ERROR: '-p' option specified for a non-DPU"
        return $EXIT_ERROR
    fi

    # No single DPU requested: on a smart switch reboot all DPUs in parallel,
    # anywhere else there is nothing to do.
    if [[ "$REBOOT_DPU" != "yes" ]]; then
        if is_smartswitch; then
            reboot_all_dpus "$NUM_DPU" "$MODULE_REBOOT_SMARTSWITCH"
            result=$?
            return $result
        fi
        return 0
    fi

    # Single DPU reboot requested.
    if ! is_smartswitch; then
        log_message "ERROR: '-d' option specified for a non-smart-switch"
        return $EXIT_ERROR
    fi

    if [[ -z $NUM_DPU ]]; then
        log_message "ERROR: Failed to retrieve number of DPUs or no DPUs found"
        return $EXIT_ERROR
    fi

    DPU_MODULE_NAME="${DPU_NAME,,}"
    verify_dpu_module_name "$DPU_MODULE_NAME" "$NUM_DPU"
    result=$?
    [[ $result -eq $EXIT_SUCCESS ]] || return $result

    reboot_dpu "$DPU_MODULE_NAME" "$MODULE_REBOOT_DPU"
    result=$?
    return $result
}
